Tutorial: Federated learning with websockets and federated averaging, with possible solutions for problems you might face

This notebook walks through each step in detail and describes the problems you might face along the way.

Make sure you have the correct websocket-client library installed. websocket-client is imported into your Python script with import websocket, so if another websocket library is installed on top of websocket-client, import websocket will pick up that other library first. Calling websocket.create_connection() then fails with an error saying that websocket has nothing named create_connection.

Solution: in a terminal, activate the environment where syft is installed, run pip uninstall websocket to remove any additional websocket libraries, then run pip install --upgrade websocket_client
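To see which package import websocket actually resolves to, a small check like the one below may help. It is only a sketch: the helper names describe_websocket_module and check_websocket_install are made up for this notebook and are not part of syft or websocket-client. It relies only on the fact that websocket-client exposes create_connection.

```python
import importlib


def describe_websocket_module(module):
    """Return (has_create_connection, file location) for an imported module."""
    location = getattr(module, "__file__", None)
    return hasattr(module, "create_connection"), location


def check_websocket_install():
    """Report whether `import websocket` resolves to websocket-client."""
    try:
        module = importlib.import_module("websocket")
    except ImportError:
        return "websocket is not importable; install websocket_client"
    ok, location = describe_websocket_module(module)
    if ok:
        return "OK: create_connection is available ({})".format(location)
    return "Wrong package: {} has no create_connection".format(location)


print(check_websocket_install())
```

If the message reports the wrong package, the printed location shows which installed library is shadowing websocket-client.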


Preparation: start the websocket server workers

Each worker is represented by two parts, a local handle (websocket client worker) and the remote instance that holds the data and performs the computations. The remote part is called a websocket server worker.

So first, you need to cd to the folder that contains this notebook and the additional files for running the server and client. This matters because the script you run in the terminal opens Python subprocesses that run the other scripts which start the websocket server workers, and it refers to those scripts only by file name and extension, since the full path varies from machine to machine. If the terminal is in a different folder, those scripts will not be found.

For example, on Windows 10:

cd (path till projects directory) \python_projects\websockets-example-MNIST

Note: Don't copy and paste the path above. It is only an example; your path may differ depending on your OS and project folder.

Next, we need to create the remote workers. For this, you need to run the start-up script in a terminal (this is not possible from the notebook).
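Because the scripts are referenced by file name only, they are resolved against the current working directory. The sketch below checks that directory before anything is started. The helper missing_files is hypothetical, and the only file name it knows is run_websocket_client.py, the module imported later in this notebook; the names of your server start-up scripts are not assumed here and should be added to the tuple.

```python
import os

# Assumption: run_websocket_client.py is the module imported later in this
# notebook. Add the file names of your own server start-up scripts here.
REQUIRED_FILES = ("run_websocket_client.py",)


def missing_files(directory, names=REQUIRED_FILES):
    """Return the names that do not exist as regular files inside directory."""
    return [
        name for name in names
        if not os.path.isfile(os.path.join(directory, name))
    ]


cwd = os.getcwd()
missing = missing_files(cwd)
if missing:
    print("Not found in {}: {}".format(cwd, ", ".join(missing)))
else:
    print("All required files found in {}".format(cwd))
```

If anything is reported as not found, cd to the project folder first and run the check again.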


Setting up the websocket client workers

We first need to perform the imports and set up some arguments and variables.

In [ ]:
%load_ext autoreload
%autoreload 2

In [ ]:
import sys
import syft as sy
from syft.workers.websocket_client import WebsocketClientWorker
import torch
from torchvision import datasets, transforms

from syft.frameworks.torch.fl import *

In [ ]:
import run_websocket_client as rwc

In [ ]:
# Pass an empty argument list so the defaults are used (args is not defined before this line).
args = rwc.define_and_get_arguments(args=[])
use_cuda = args.cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

Now let's instantiate the websocket client workers, our local access point to the remote workers. Note that this step will fail if the websocket server workers are not running.

In [ ]:
hook = sy.TorchHook(torch)

kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

workers = [alice, bob, charlie]
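If the cell above fails with a connection error, it can help to check whether anything is listening on the three ports at all. The following is a minimal sketch using only the standard library. The helper port_is_open is hypothetical, and it only tests plain TCP reachability, not whether the listener is really a websocket server worker.

```python
import socket


def port_is_open(host, port, timeout=1.0):
    """Return True if a TCP connection to (host, port) can be opened."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts and unreachable hosts.
        return False


# Ports taken from the cell above: alice, bob and charlie.
WORKER_PORTS = {"alice": 8777, "bob": 8778, "charlie": 8779}

for name, port in WORKER_PORTS.items():
    status = "reachable" if port_is_open("localhost", port) else "NOT reachable"
    print("{} (port {}): {}".format(name, port, status))
```

A worker reported as NOT reachable usually means the corresponding server process was not started or has already exited.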

Prepare and distribute the training data

We will use the MNIST dataset and distribute the data randomly onto the workers. This is not realistic for a federated training setup, where the data would normally already be available at the remote workers.
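To make "distribute the data randomly" concrete, here is a small standalone sketch that shuffles sample indices and deals them out to the workers. It is an illustration only: split_indices is a hypothetical helper, and this is not necessarily how syft partitions a dataset internally. The 60000 is the size of the MNIST training set.

```python
import random


def split_indices(num_samples, worker_ids, seed=0):
    """Shuffle sample indices and deal them out to the workers round-robin."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    return {
        worker_id: indices[position::len(worker_ids)]
        for position, worker_id in enumerate(worker_ids)
    }


shares = split_indices(60000, ["alice", "bob", "charlie"])
for worker_id, share in shares.items():
    print(worker_id, len(share))
```

Each worker ends up with a disjoint share of 20000 indices, and the fixed seed makes the split reproducible.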

We instantiate two FederatedDataLoaders, one for the train and one for the test set of the MNIST dataset.

If you run into BrokenPipe errors, go to the parent directory of your project directory and delete the data folder, then restart the notebook and try again. If the error comes back, delete the data folder again and run the following cell.

For example, the directory for the data:

(path till projects directory) \python_projects\

and the directory for the project notebook and scripts:

(path till projects directory) \python_projects\websockets-example-MNIST

Note: Don't copy and paste the paths above. They are only examples; your paths may differ depending on your OS and project folder.

In [ ]:
# Run this cell only if the next cell gives a BrokenPipe error

In [ ]:
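# Assumption: this cell was truncated; the MNIST arguments and loader options are reconstructed ("../data" is the data folder in the parent directory).
mnist_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_set = datasets.MNIST("../data", train=True, download=True, transform=mnist_transform)
federated_train_loader = sy.FederatedDataLoader(
    train_set.federate(tuple(workers)), batch_size=args.batch_size, shuffle=True, iter_per_worker=True)

# Assumption: the test set is evaluated locally, so a plain PyTorch DataLoader is used here.
test_set = datasets.MNIST("../data", train=False, transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.test_batch_size, shuffle=True)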

Next, we need to instantiate the machine learning model. It is a small neural network with two convolutional and two fully connected layers. It uses ReLU activations and max pooling.

In [ ]:
model = rwc.Net().to(device)

In [ ]:
import logging
import sys
logger = logging.getLogger()
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter("%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s")
# Attach the formatter; without this line it is created but never used.
handler.setFormatter(formatter)
logger.handlers = [handler]

Let's start the training

Now we are ready to start the federated training. We will train for a given number of batches separately on each worker, then calculate the federated average of the resulting models and measure the test accuracy of the averaged model.

In [ ]:
for epoch in range(1, args.epochs + 1):
    print("Starting epoch {}/{}".format(epoch, args.epochs))
    # Assumption: the argument missing from the original call is the learning rate (args.lr).
    model = rwc.train(model, device, federated_train_loader, args.lr, args.federate_after_n_batches)
    rwc.test(model, device, test_loader)
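The averaging step can be illustrated without any deep learning library. The sketch below is not the code used by rwc.train. It is a hypothetical, unweighted federated average over plain Python lists, and it assumes every worker trained on the same number of samples and holds the same parameter names.

```python
def federated_average(worker_params):
    """Average parameters element-wise over all workers.

    worker_params maps a worker id to a dict {parameter name: list of floats}.
    """
    if not worker_params:
        raise ValueError("need at least one worker model")
    models = list(worker_params.values())
    averaged = {}
    for name in models[0]:
        # Group the i-th value of this parameter from every worker together.
        columns = zip(*(model[name] for model in models))
        averaged[name] = [sum(values) / len(models) for values in columns]
    return averaged


# Toy "models" with one weight vector and one bias per worker.
local_models = {
    "alice": {"weight": [1.0, 2.0], "bias": [0.0]},
    "bob": {"weight": [3.0, 4.0], "bias": [3.0]},
    "charlie": {"weight": [5.0, 6.0], "bias": [6.0]},
}
print(federated_average(local_models))
# → {'weight': [3.0, 4.0], 'bias': [3.0]}
```

In the real training loop the same idea is applied to the tensors of each worker's model after the local batches, and the averaged model is what gets tested.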

